# coding:utf-8

__author__ = 'carey@akhack.com'

from django.utils.deprecation import MiddlewareMixin
from django.http.response import HttpResponse
from django_redis import get_redis_connection
from hashlib import md5


class RequestBlockMiddlewareMixin(MiddlewareMixin):
    """
    Django middleware that rate-limits incoming requests per client address.

    Each client is identified by the MD5 of its IP address; a counter stored
    in Redis under that digest is incremented on every request and expires
    ``expire`` seconds after the window starts.  Once the counter exceeds
    ``limit`` the request is rejected before reaching the view.

    Usage notes:
      * Add it to the Django middleware setting, above
        ``AuthenticationMiddleware``.
      * Requires the ``django_redis`` package.
    """

    limit = 4  # Maximum number of requests allowed within one window.
    expire = 1  # Window length, handed to Redis EXPIRE (seconds).
    cache = "default"  # Cache alias passed to get_redis_connection().
    # HTTP status of the rejection response.  503 is kept as the default for
    # backward compatibility; 429 (Too Many Requests) is the conventional
    # code and can be selected by overriding this attribute.
    blocked_status = 503

    # Executed atomically inside Redis: bump the counter, make sure it has a
    # TTL, and return the new count.  The ``ttl == -1`` check re-arms the
    # expiry on a counter that exists without one, which would otherwise
    # never reset and block the client permanently.
    _COUNTER_SCRIPT = """
        local current = redis.call("incr", KEYS[1])
        if current == 1 or redis.call("ttl", KEYS[1]) == -1 then
            redis.call("expire", KEYS[1], ARGV[1])
        end
        return current
        """

    def process_request(self, request):
        """
        Reject the request when the client exceeded its quota.

        :param request: the incoming Django request
        :return: an ``HttpResponse`` with status ``blocked_status`` when the
                 per-window count is greater than ``limit``; otherwise
                 ``None`` so that request processing continues
        """
        if self.set_key(request) > self.limit:
            return HttpResponse("请求频率过快，请稍后重试", status=self.blocked_status)
        return None

    @staticmethod
    def get_ident(request, num_proxies=1):
        """
        Identify the machine making the request.

        NOTE: ``X-Forwarded-For`` is supplied by the client and is only
        trustworthy for the entries appended by proxies you control, so
        ``num_proxies`` should match the real deployment.

        :param request: object exposing a ``META`` mapping
        :param num_proxies: number of trusted proxies in front of the
            application.  ``0`` ignores ``HTTP_X_FORWARDED_FOR`` and uses
            ``REMOTE_ADDR``; a positive value picks the entry that many
            positions from the end of the header (clamped to the first
            entry); ``None`` uses the whole header with whitespace removed.
        :return: the client address string, or ``None`` when neither
            ``HTTP_X_FORWARDED_FOR`` nor ``REMOTE_ADDR`` is available
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')

        if num_proxies is None:
            return ''.join(xff.split()) if xff else remote_addr

        if num_proxies == 0 or xff is None:
            return remote_addr

        addrs = xff.split(',')
        client_addr = addrs[-min(num_proxies, len(addrs))].strip()
        # A blank header entry carries no identity; fall back to the socket
        # address instead of returning an empty string.
        return client_addr or remote_addr

    def get_md5(self, request):
        """
        Return the hex MD5 digest of the client address.

        When no address can be determined the empty string is hashed, so
        all such requests share a single counter instead of raising.

        :param request: the incoming Django request
        :return: 32-character hexadecimal digest used as the Redis key
        """
        ip_str = self.get_ident(request) or ''
        return md5(ip_str.encode("utf-8")).hexdigest()

    def set_key(self, request):
        """
        Increment the client's request counter in Redis via a Lua script.

        The window expiry is set when the counter is created.

        :param request: the incoming Django request
        :return: number of requests seen from this client in the current
                 window, including this one
        """
        key = self.get_md5(request)
        redis_cli = get_redis_connection(self.cache)
        return int(redis_cli.eval(self._COUNTER_SCRIPT, 1, key, self.expire))
